{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "AOpGoE2T-YXS"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\").\n",
        "\n",
        "# Neural Machine Translation with Attention\n",
        "\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "    \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/sequences/_nmt.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003e\n",
        "    Run in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/sequences/_nmt.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003e\n",
        "    View source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CiwtNgENbx2g"
      },
      "source": [
        "# This notebook is still under construction! Please come back later.\n",
        "\n",
        "\n",
        "This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using TF 2.0 APIs. This is an advanced example that assumes some knowledge of sequence to sequence models.\n",
        "\n",
        "After training the model in this notebook, you will be able to input a Spanish sentence, such as *\"¿todavia estan en casa?\"*, and return the English translation: *\"are you still at home?\"*\n",
        "\n",
        "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n",
        "\n",
        "\u003cimg src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\"\u003e\n",
        "\n",
        "Note: This example takes approximately 10 mintues to run on a single P100 GPU."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tnxXKDjq3jEL"
      },
      "outputs": [],
      "source": [
        "import collections\n",
        "import io\n",
        "import itertools\n",
        "import os\n",
        "import random\n",
        "import re\n",
        "import time\n",
        "import unicodedata\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "import tensorflow as tf\n",
        "assert tf.__version__.startswith('2')\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "print(tf.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wfodePkj3jEa"
      },
      "source": [
        "## Download and prepare the dataset\n",
        "\n",
        "We'll use a language dataset provided by http://www.manythings.org/anki/. This dataset contains language translation pairs in the format:\n",
        "\n",
        "```\n",
        "May I borrow this book?\t¿Puedo tomar prestado este libro?\n",
        "```\n",
        "\n",
        "There are a variety of languages available, but we'll use the English-Spanish dataset. For convenience, we've hosted a copy of this dataset on Google Cloud, but you can also download your own copy. After downloading the dataset, here are the steps we'll take to prepare the data:\n",
        "\n",
        "1. Clean the sentences by removing special characters.\n",
        "1. Add a *start* and *end* token to each sentence.\n",
        "1. Create a word index and reverse word index (dictionaries mapping from word → id and id → word).\n",
        "1. Pad each sentence to a maximum length."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "kRVATYOgJs1b"
      },
      "outputs": [],
      "source": [
        "# TODO(brianklee): This preprocessing should ideally be implemented in TF\n",
        "# because preprocessing should be exported as part of the SavedModel.\n",
        "\n",
        "# Converts the unicode file to ascii\n",
        "# https://stackoverflow.com/a/518232/2809427\n",
        "def unicode_to_ascii(s):\n",
        "  return ''.join(c for c in unicodedata.normalize('NFD', s)\n",
        "                 if unicodedata.category(c) != 'Mn')\n",
        "\n",
        "START_TOKEN = u'\u003cstart\u003e'\n",
        "END_TOKEN = u'\u003cend\u003e'\n",
        "\n",
        "def preprocess_sentence(w):\n",
        "  # remove accents; lowercase everything\n",
        "  w = unicode_to_ascii(w.strip()).lower()\n",
        "\n",
        "  # creating a space between a word and the punctuation following it\n",
        "  # eg: \"he is a boy.\" =\u003e \"he is a boy .\"\n",
        "  # https://stackoverflow.com/a/3645931/3645946\n",
        "  w = re.sub(r'([?.!,¿])', r' \\1 ', w)\n",
        "\n",
        "  # replacing everything with space except (a-z, '.', '?', '!', ',')\n",
        "  w = re.sub(r'[^a-z?.!,¿]+', ' ', w)\n",
        "\n",
        "  # adding a start and an end token to the sentence\n",
        "  # so that the model know when to start and stop predicting.\n",
        "  w = '\u003cstart\u003e ' + w + ' \u003cend\u003e'\n",
        "  return w\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "PbX9r8blNIUu"
      },
      "outputs": [],
      "source": [
        "en_sentence = u\"May I borrow this book?\"\n",
        "sp_sentence = u\"¿Puedo tomar prestado este libro?\"\n",
        "print(preprocess_sentence(en_sentence))\n",
        "print(preprocess_sentence(sp_sentence))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RNWJCJIwPSZp"
      },
      "source": [
        "Training on the complete dataset of \u003e100,000 sentences will take a long time. To train faster, we can limit the size of the dataset (of course, translation quality degrades with less data).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "OHn4Dct23jEm"
      },
      "outputs": [],
      "source": [
        "def load_anki_data(num_examples=None):\n",
        "  # Download the file\n",
        "  path_to_zip = tf.keras.utils.get_file(\n",
        "      'spa-eng.zip', origin='http://download.tensorflow.org/data/spa-eng.zip',\n",
        "      extract=True)\n",
        "\n",
        "  path_to_file = os.path.dirname(path_to_zip) + '/spa-eng/spa.txt'\n",
        "  with io.open(path_to_file, 'rb') as f:\n",
        "    lines = f.read().decode('utf8').strip().split('\\n')\n",
        "\n",
        "  # Data comes as tab-separated strings; one per line.\n",
        "  eng_spa_pairs = [[preprocess_sentence(w) for w in line.split('\\t')] for line in lines]\n",
        "\n",
        "  # The translations file is ordered from shortest to longest, so slicing from\n",
        "  # the front will select the shorter examples. This also speeds up training.\n",
        "  if num_examples is not None:\n",
        "    eng_spa_pairs = eng_spa_pairs[:num_examples]\n",
        "  eng_sentences, spa_sentences = zip(*eng_spa_pairs)\n",
        "\n",
        "  eng_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n",
        "  spa_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n",
        "  eng_tokenizer.fit_on_texts(eng_sentences)\n",
        "  spa_tokenizer.fit_on_texts(spa_sentences)\n",
        "  return (eng_spa_pairs, eng_tokenizer, spa_tokenizer)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "mfI6wprNPSZs"
      },
      "outputs": [],
      "source": [
        "NUM_EXAMPLES = 30000\n",
        "sentence_pairs, english_tokenizer, spanish_tokenizer = load_anki_data(NUM_EXAMPLES)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "eAY9k49G3jE_"
      },
      "outputs": [],
      "source": [
        "# Turn our english/spanish pairs into TF Datasets by mapping words -\u003e integers.\n",
        "\n",
        "def make_dataset(eng_spa_pairs, eng_tokenizer, spa_tokenizer):\n",
        "  eng_sentences, spa_sentences = zip(*eng_spa_pairs)\n",
        "  eng_ints = eng_tokenizer.texts_to_sequences(eng_sentences)\n",
        "  spa_ints = spa_tokenizer.texts_to_sequences(spa_sentences)\n",
        "\n",
        "  padded_eng_ints = tf.keras.preprocessing.sequence.pad_sequences(\n",
        "      eng_ints, padding='post')\n",
        "  padded_spa_ints = tf.keras.preprocessing.sequence.pad_sequences(\n",
        "      spa_ints, padding='post')\n",
        "\n",
        "  dataset = tf.data.Dataset.from_tensor_slices((padded_eng_ints, padded_spa_ints))\n",
        "  return dataset\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "sEyV3Vd4PSZy"
      },
      "outputs": [],
      "source": [
        "# Train/test split\n",
        "train_size = int(len(sentence_pairs) * 0.8)\n",
        "random.shuffle(sentence_pairs)\n",
        "train_sentence_pairs, test_sentence_pairs = sentence_pairs[:train_size], sentence_pairs[train_size:]\n",
        "# Show length\n",
        "len(train_sentence_pairs), len(test_sentence_pairs)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "_vdlLeT9PSZ4"
      },
      "outputs": [],
      "source": [
        "_english, _spanish = train_sentence_pairs[0]\n",
        "_eng_ints, _spa_ints = english_tokenizer.texts_to_sequences([_english])[0], spanish_tokenizer.texts_to_sequences([_spanish])[0]\n",
        "print(\"Source language: \")\n",
        "print('\\n'.join('{:4d} ----\u003e {}'.format(i, word) for i, word in zip(_eng_ints, _english.split())))\n",
        "print(\"Target language: \")\n",
        "print('\\n'.join('{:4d} ----\u003e {}'.format(i, word) for i, word in zip(_spa_ints, _spanish.split())))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4QILQkOs3jFG"
      },
      "outputs": [],
      "source": [
        "# Set up datasets\n",
        "BATCH_SIZE = 64\n",
        "\n",
        "train_ds = make_dataset(train_sentence_pairs, english_tokenizer, spanish_tokenizer)\n",
        "test_ds = make_dataset(test_sentence_pairs, english_tokenizer, spanish_tokenizer)\n",
        "train_ds = train_ds.shuffle(len(train_sentence_pairs)).batch(BATCH_SIZE, drop_remainder=True)\n",
        "test_ds = test_ds.batch(BATCH_SIZE, drop_remainder=True)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "YM69ROrxPSZ7"
      },
      "outputs": [],
      "source": [
        "print(\"Dataset outputs elements with shape ({}, {})\".format(\n",
        "    *train_ds.output_shapes))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TNfHIF71ulLu"
      },
      "source": [
        "## Write the encoder and decoder model\n",
        "\n",
        "Here, we'll implement an encoder-decoder model with attention. The following diagram shows that each input word is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n",
        "\n",
        "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\"\u003e\n",
        "\n",
        "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. \n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Pm0KGxEbPSaB"
      },
      "outputs": [],
      "source": [
        "ENCODER_SIZE = DECODER_SIZE = 1024\n",
        "EMBEDDING_DIM = 256\n",
        "MAX_OUTPUT_LENGTH = train_ds.output_shapes[1][1]\n",
        "\n",
        "def gru(units):\n",
        "  return tf.keras.layers.GRU(units,\n",
        "                             return_sequences=True,\n",
        "                             return_state=True,\n",
        "                             recurrent_activation='sigmoid',\n",
        "                             recurrent_initializer='glorot_uniform')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "nZ2rI24i3jFg"
      },
      "outputs": [],
      "source": [
        "class Encoder(tf.keras.Model):\n",
        "  def __init__(self, vocab_size, embedding_dim, encoder_size):\n",
        "    super(Encoder, self).__init__()\n",
        "    self.embedding_dim = embedding_dim\n",
        "    self.encoder_size = encoder_size\n",
        "    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
        "    self.gru = gru(encoder_size)\n",
        "\n",
        "  def call(self, x, hidden):\n",
        "    x = self.embedding(x)\n",
        "    output, state = self.gru(x, initial_state=hidden)\n",
        "    return output, state\n",
        "\n",
        "  def initial_hidden_state(self, batch_size):\n",
        "    return tf.zeros((batch_size, self.encoder_size))\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kvzxABcg91RS"
      },
      "source": [
        "\n",
        "For the decoder, we're using *Bahdanau attention*. Here are the equations that are implemented:\n",
        "\n",
        "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\"\u003e\n",
        "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\"\u003e\n",
        "\n",
        "Lets decide on notation before writing the simplified form:\n",
        "\n",
        "* FC = Fully connected (dense) layer\n",
        "* EO = Encoder output\n",
        "* H = hidden state\n",
        "* X = input to the decoder\n",
        "\n",
        "And the pseudo-code:\n",
        "\n",
        "* `score = FC(tanh(FC(EO) + FC(H)))`\n",
        "* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, hidden_size)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n",
        "* `context vector = sum(attention weights * EO, axis = 1)`. Same reason as above for choosing axis as 1.\n",
        "* `embedding output` = The input to the decoder X is passed through an embedding layer.\n",
        "* `merged vector = concat(embedding output, context vector)`\n",
        "* This merged vector is then given to the GRU\n",
        "  \n",
        "The shapes of all the vectors at each step have been specified in the comments in the code:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "yJ_B3mhW3jFk"
      },
      "outputs": [],
      "source": [
        "class BahdanauAttention(tf.keras.Model):\n",
        "  def __init__(self, units):\n",
        "    super(BahdanauAttention, self).__init__()\n",
        "    self.W1 = tf.keras.layers.Dense(units)\n",
        "    self.W2 = tf.keras.layers.Dense(units)\n",
        "    self.V = tf.keras.layers.Dense(1)\n",
        "  \n",
        "  def call(self, hidden_state, enc_output):\n",
        "    # enc_output shape = (batch_size, max_length, hidden_size)\n",
        "\n",
        "    # (batch_size, hidden_size) -\u003e (batch_size, 1, hidden_size)\n",
        "    hidden_with_time = tf.expand_dims(hidden_state, 1)\n",
        "    \n",
        "    # score shape == (batch_size, max_length, 1)\n",
        "    score = self.V(tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time)))\n",
        "    # attention_weights shape == (batch_size, max_length, 1)\n",
        "    attention_weights = tf.nn.softmax(score, axis=1)\n",
        "\n",
        "    # context_vector shape after sum = (batch_size, hidden_size)\n",
        "    context_vector = attention_weights * enc_output\n",
        "    context_vector = tf.reduce_sum(context_vector, axis=1)\n",
        "    \n",
        "    return context_vector, attention_weights\n",
        "\n",
        "\n",
        "class Decoder(tf.keras.Model):\n",
        "  def __init__(self, vocab_size, embedding_dim, decoder_size):\n",
        "    super(Decoder, self).__init__()\n",
        "    self.vocab_size = vocab_size\n",
        "    self.embedding_dim = embedding_dim\n",
        "    self.decoder_size = decoder_size\n",
        "    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
        "    self.gru = gru(decoder_size)\n",
        "    self.fc = tf.keras.layers.Dense(vocab_size)\n",
        "    self.attention = BahdanauAttention(decoder_size)\n",
        "\n",
        "  def call(self, x, hidden, enc_output):\n",
        "    context_vector, attention_weights = self.attention(hidden, enc_output)\n",
        "\n",
        "    # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
        "    x = self.embedding(x)\n",
        "\n",
        "    # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
        "    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
        "\n",
        "    # passing the concatenated vector to the GRU\n",
        "    output, state = self.gru(x)\n",
        "\n",
        "    # output shape == (batch_size, hidden_size)\n",
        "    output = tf.reshape(output, (-1, output.shape[2]))\n",
        "\n",
        "    # output shape == (batch_size, vocab)\n",
        "    x = self.fc(output)\n",
        "\n",
        "    return x, state, attention_weights\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ErMEwyflPSaJ"
      },
      "source": [
        "## Define a translate function\n",
        "\n",
        "Now, let's put the encoder and decoder halves together. The encoder step is fairly straightforward; we'll just reuse Keras's dynamic unroll. For the decoder, we have to make some choices about how to feed the decoder RNN. Overall the process goes as follows:\n",
        "\n",
        "1. Pass the *input* through the *encoder* which return *encoder output* and the *encoder hidden state*.\n",
        "2. The encoder output, encoder hidden state and the \u0026lt;START\u0026gt; token is passed to the decoder.\n",
        "3. The decoder returns the *predictions* and the *decoder hidden state*.\n",
        "4. The encoder output, hidden state and next token is then fed back into the decoder repeatedly. This has two different behaviors under training and inference:\n",
        "  - during training, we use *teacher forcing*, where the correct next token is fed into the decoder, regardless of what the decoder emitted.\n",
        "  - during inference, we use `tf.argmax(predictions)` to select the most likely continuation and feed it back into the decoder. Another strategy that yields more robust results is called *beam search*.\n",
        "5. Repeat step 4 until either the decoder emits an \u0026lt;END\u0026gt; token, indicating that it's done translating, or we run into a hardcoded length limit. \n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "PqfjRVuRPSaK"
      },
      "outputs": [],
      "source": [
        "class NmtTranslator(tf.keras.Model):\n",
        "  def __init__(self, encoder, decoder, start_token_id, end_token_id):\n",
        "    super(NmtTranslator, self).__init__()\n",
        "    self.encoder = encoder\n",
        "    self.decoder = decoder\n",
        "    # (The token_id should match the decoder's language.)\n",
        "    # Uses start_token_id to initialize the decoder.\n",
        "    self.start_token_id = tf.constant(start_token_id)\n",
        "    # Check for sequence completion using this token_id\n",
        "    self.end_token_id = tf.constant(end_token_id)\n",
        "\n",
        "\n",
        "  @tf.function \n",
        "  def call(self, inp, target=None, max_output_length=MAX_OUTPUT_LENGTH):\n",
        "    '''Translate an input.\n",
        "\n",
        "    If target is provided, teacher forcing is used to generate the translation.\n",
        "    '''\n",
        "    batch_size = inp.shape[0]\n",
        "    hidden = self.encoder.initial_hidden_state(batch_size)\n",
        "\n",
        "    enc_output, enc_hidden = self.encoder(inp, hidden)\n",
        "    dec_hidden = enc_hidden\n",
        "\n",
        "    if target is not None:\n",
        "      output_length = target.shape[1]\n",
        "    else:\n",
        "      output_length = max_output_length\n",
        "\n",
        "    predictions_array = tf.TensorArray(tf.float32, size=output_length - 1)\n",
        "    attention_array = tf.TensorArray(tf.float32, size=output_length - 1)\n",
        "    # Feed \u003cSTART\u003e token to start decoder.\n",
        "    dec_input = tf.cast([self.start_token_id] * batch_size, tf.int32)\n",
        "    # Keep track of which sequences have emitted an \u003cEND\u003e token\n",
        "    is_done = tf.zeros([batch_size], dtype=tf.bool)\n",
        "\n",
        "    for i in tf.range(output_length - 1):\n",
        "      dec_input = tf.expand_dims(dec_input, 1)\n",
        "      predictions, dec_hidden, attention_weights = self.decoder(dec_input, dec_hidden, enc_output)\n",
        "      predictions = tf.where(is_done, tf.zeros_like(predictions), predictions)\n",
        "      \n",
        "      # Write predictions/attention for later visualization.\n",
        "      predictions_array = predictions_array.write(i, predictions)\n",
        "      attention_array = attention_array.write(i, attention_weights)\n",
        "\n",
        "      # Decide what to pass into the next iteration of the decoder.\n",
        "      if target is not None:\n",
        "        # if target is known, use teacher forcing\n",
        "        dec_input = target[:, i + 1]\n",
        "      else:\n",
        "        # Otherwise, pick the most likely continuation\n",
        "        dec_input = tf.argmax(predictions, axis=1, output_type=tf.int32)\n",
        "\n",
        "      # Figure out which sentences just completed.\n",
        "      is_done = tf.logical_or(is_done, tf.equal(dec_input, self.end_token_id))\n",
        "      # Exit early if all our sentences are done.\n",
        "      if tf.reduce_all(is_done):\n",
        "        break\n",
        "\n",
        "    # [time, batch, predictions] -\u003e [batch, time, predictions]\n",
        "    return tf.transpose(predictions_array.stack(), [1, 0, 2]), tf.transpose(attention_array.stack(), [1, 0, 2, 3])\n",
        "    \n",
        "  \n",
        "    "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_ch_71VbIRfK"
      },
      "source": [
        "## Define the loss function\n",
        "\n",
        "Our loss function is a word-for-word comparison between true answer and model prediction.\n",
        "  \n",
        "    real = [\u003cstart\u003e, 'This', 'is', 'the', 'correct', 'answer', '.', '\u003cend\u003e', '\u003coov\u003e']\n",
        "    pred = ['This', 'is', 'what', 'the', 'model', 'emitted', '.', '\u003cend\u003e']\n",
        "\n",
        "results in comparing\n",
        "\n",
        "    This/This, is/is, the/what, correct/the, answer/model, ./emitted, \u003cend\u003e/.\n",
        "and ignoring the rest of the prediction.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "WmTHr5iV3jFr"
      },
      "outputs": [],
      "source": [
        "def loss_fn(real, pred):\n",
        "  # The prediction doesn't include the \u003cstart\u003e token.\n",
        "  real = real[:, 1:]\n",
        "  # Cut down the prediction to the correct shape (We ignore extra words).\n",
        "  pred = pred[:, :real.shape[1]]\n",
        "  # If real == \u003cOOV\u003e, then mask out the loss.\n",
        "  mask = 1 - np.equal(real, 0)\n",
        "  loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n",
        "\n",
        "  # Sum loss over the time dimension, but average it over the batch dimension.\n",
        "  return tf.reduce_mean(tf.reduce_sum(loss_, axis=1))\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DMVWzzsfNl4e"
      },
      "source": [
        "## Configure model directory\n",
        "\n",
        "We'll use one directory to save all of our relevant artifacts (summary logs, checkpoints, SavedModel exports, etc.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Zj8bXQTgNwrF"
      },
      "outputs": [],
      "source": [
        "# Where to save checkpoints, tensorboard summaries, etc.\n",
        "MODEL_DIR = '/tmp/tensorflow/nmt_attention'\n",
        "\n",
        "def apply_clean():\n",
        "  if tf.io.gfile.exists(MODEL_DIR):\n",
        "    print('Removing existing model dir: {}'.format(MODEL_DIR))\n",
        "    tf.io.gfile.rmtree(MODEL_DIR)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "mO2d7e6gTlRA"
      },
      "outputs": [],
      "source": [
        "# Optional: remove existing data\n",
        "apply_clean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "rlR-g2hR5Hl0"
      },
      "outputs": [],
      "source": [
        "# Summary writers\n",
        "train_summary_writer = tf.summary.create_file_writer(\n",
        "  os.path.join(MODEL_DIR, 'summaries', 'train'), flush_millis=10000)\n",
        "test_summary_writer = tf.summary.create_file_writer(\n",
        "  os.path.join(MODEL_DIR, 'summaries', 'eval'), flush_millis=10000, name='test')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "LttA5h8C8yOU"
      },
      "outputs": [],
      "source": [
        "# Set up all stateful objects\n",
        "encoder = Encoder(len(english_tokenizer.word_index) + 1, EMBEDDING_DIM, ENCODER_SIZE)\n",
        "decoder = Decoder(len(spanish_tokenizer.word_index) + 1, EMBEDDING_DIM, DECODER_SIZE)\n",
        "start_token_id = spanish_tokenizer.word_index[START_TOKEN]\n",
        "end_token_id = spanish_tokenizer.word_index[END_TOKEN]\n",
        "model = NmtTranslator(encoder, decoder, start_token_id, end_token_id)\n",
        "\n",
        "# TODO(brianklee): Investigate whether Adam defaults have changed and whether it affects training.\n",
        "optimizer = tf.keras.optimizers.Adam(epsilon=1e-8)#   tf.keras.optimizers.SGD(learning_rate=0.01)#Adam()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "JH1TiGS5PSak"
      },
      "outputs": [],
      "source": [
        "# Checkpoints\n",
        "checkpoint_dir = os.path.join(MODEL_DIR, 'checkpoints')\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')\n",
        "checkpoint = tf.train.Checkpoint(\n",
        "    encoder=encoder, decoder=decoder, optimizer=optimizer)\n",
        "# Restore variables on creation if a checkpoint exists.\n",
        "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ssRahnCqAVe6"
      },
      "outputs": [],
      "source": [
        "# SavedModel exports\n",
        "export_path = os.path.join(MODEL_DIR, 'export')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qtdJxgsFZdHi"
      },
      "source": [
        "# Visualize the model's output\n",
        "\n",
        "Let's visualize our model's output. (It hasn't been trained yet, so it will output gibberish.)\n",
        "\n",
        "We'll use this visualization to check on the model's progress."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "3fGS9Jd3Zsai"
      },
      "outputs": [],
      "source": [
        "def plot_attention(attention, sentence, predicted_sentence):\n",
        "    fig = plt.figure(figsize=(10,10))\n",
        "    ax = fig.add_subplot(1, 1, 1)\n",
        "    ax.matshow(attention, cmap='viridis')\n",
        "    \n",
        "    fontdict = {'fontsize': 14}\n",
        "    \n",
        "    ax.set_xticklabels([''] + sentence.split(), fontdict=fontdict, rotation=90)\n",
        "    ax.set_yticklabels([''] + predicted_sentence.split(), fontdict=fontdict)\n",
        "    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
        "    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
        "\n",
        "    plt.show()\n",
        "\n",
        "def ints_to_words(tokenizer, ints):\n",
        "    return ' '.join(tokenizer.index_word[int(i)] if int(i) != 0 else '\u003cOOV\u003e' for i in ints)\n",
        "  \n",
        "def sentence_to_ints(tokenizer, sentence):\n",
        "    sentence = preprocess_sentence(sentence)\n",
        "    return tf.constant(tokenizer.texts_to_sequences([sentence])[0])\n",
        "\n",
        "def translate_and_plot_ints(model, english_tokenizer, spanish_tokenizer, ints, target_ints=None):\n",
        "    \"\"\"Run translation on a sentence and plot an attention matrix.\n",
        "    \n",
        "    Sentence should be passed in as list of integers.\n",
        "    \"\"\"\n",
        "    ints = tf.expand_dims(ints, 0)\n",
        "    predictions, attention = model(ints)\n",
        "    prediction_ids = tf.squeeze(tf.argmax(predictions, axis=-1))\n",
        "    attention = tf.squeeze(attention)\n",
        "    sentence = ints_to_words(english_tokenizer, ints[0])\n",
        "    predicted_sentence = ints_to_words(spanish_tokenizer, prediction_ids)\n",
        "    print(u'Input: {}'.format(sentence))\n",
        "    print(u'Predicted translation: {}'.format(predicted_sentence))\n",
        "    if target_ints is not None:\n",
        "      print(u'Correct translation: {}'.format(ints_to_words(spanish_tokenizer, target_ints)))\n",
        "    plot_attention(attention, sentence, predicted_sentence)    \n",
        "\n",
        "def translate_and_plot_words(model, english_tokenizer, spanish_tokenizer, sentence, target_sentence=None):\n",
        "    \"\"\"Same as translate_and_plot_ints, but pass in a sentence as a string.\"\"\"\n",
        "    english_ints = sentence_to_ints(english_tokenizer, sentence)\n",
        "    spanish_ints = sentence_to_ints(spanish_tokenizer, target_sentence) if target_sentence is not None else None\n",
        "    translate_and_plot_ints(model, english_tokenizer, spanish_tokenizer, english_ints, target_ints=spanish_ints)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "sozn-RRBZzoa"
      },
      "outputs": [],
      "source": [
        "translate_and_plot_words(model, english_tokenizer, spanish_tokenizer, u\"it's really cold here\", u'hace mucho frio aqui')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tNdYJ8igFHTt"
      },
      "source": [
        "# Train the model\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "z5wVVaWCY8nf"
      },
      "outputs": [],
      "source": [
        "def train(model, optimizer, dataset):\n",
        "  \"\"\"Trains model on `dataset` using `optimizer`.\"\"\"\n",
        "  start = time.time()\n",
        "  avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)\n",
        "  for inp, target in dataset:\n",
        "    with tf.GradientTape() as tape:\n",
        "      predictions, _ = model(inp, target=target)\n",
        "      loss = loss_fn(target, predictions)\n",
        "\n",
        "    avg_loss(loss)\n",
        "    gradients = tape.gradient(loss, model.trainable_variables)\n",
        "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "    if tf.equal(optimizer.iterations % 10, 0):\n",
        "      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)\n",
        "      avg_loss.reset_states()\n",
        "      rate = 10 / (time.time() - start)\n",
        "      print('Step #%d\\tLoss: %.6f (%.2f steps/sec)' % (optimizer.iterations, loss, rate))\n",
        "      start = time.time()\n",
        "    if tf.equal(optimizer.iterations % 100, 0):\n",
        "#       translate_and_plot_words(model, english_index, spanish_index, u\"it's really cold here.\", u'hace mucho frio aqui.')\n",
        "      translate_and_plot_ints(model, english_tokenizer, spanish_tokenizer, inp[0], target[0])\n",
        "\n",
        "def test(model, dataset, step_num):\n",
        "  \"\"\"Perform an evaluation of `model` on the examples from `dataset`.\"\"\"\n",
        "  avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)\n",
        "  for inp, target in dataset:\n",
        "    predictions, _ = model(inp)\n",
        "    loss = loss_fn(target, predictions)\n",
        "    avg_loss(loss)\n",
        "\n",
        "  print('Model test set loss: {:0.4f}'.format(avg_loss.result()))\n",
        "  tf.summary.scalar('loss', avg_loss.result(), step=step_num)\n",
        "\n",
        "\n",
        "      "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ddefjBMa3jF0",
        "scrolled": true
      },
      "outputs": [],
      "source": [
        "NUM_TRAIN_EPOCHS = 10\n",
        "for i in range(NUM_TRAIN_EPOCHS):\n",
        "  start = time.time()\n",
        "  with train_summary_writer.as_default():\n",
        "    train(model, optimizer, train_ds)\n",
        "  end = time.time()\n",
        "  print('\\nTrain time for epoch #{} ({} total steps): {}'.format(\n",
        "      i + 1, optimizer.iterations, end - start))\n",
        "  with test_summary_writer.as_default():\n",
        "    test(model, test_ds, optimizer.iterations)\n",
        "  checkpoint.save(checkpoint_prefix)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "yvqeRHw2PSaq"
      },
      "outputs": [],
      "source": [
        "# TODO(brianklee): This seems to be complaining about input shapes not being set?\n",
        "# tf.saved_model.save(model, export_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RTe5P5ioMJwN"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "* [Download a different dataset](http://www.manythings.org/anki/) to experiment with translations, for example, English to German, or English to French.\n",
        "* Experiment with training on a larger dataset, or using more epochs\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "nmt_with_attention.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U",
          "timestamp": 1527858391290
        },
        {
          "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv",
          "timestamp": 1527776041613
        }
      ],
      "toc_visible": true,
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
